""" DiffBC Implementation """
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import numpy as np

from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim

from diffgro.environments.collect_dataset import get_skill_embed
from diffgro.diffbc.planner import DiffBCPlanner
from diffgro.utils import print_b


class DiffBC:
    """Inference-time wrapper that queries a DiffBC planner's policy for actions.

    The task embedding is looked up once at construction time from the
    environment's ``env_name`` and reused on every ``predict`` call.
    """

    def __init__(
        self,
        env: gym.Env,
        planner: DiffBCPlanner,
        verbose: bool = False,
    ):
        """Store the environment and the planner's policy, then run setup.

        Args:
            env: Environment whose observation space, action space and
                ``env_name`` attribute are read during setup.
            planner: Planner whose ``policy`` attribute is used for inference.
            verbose: Stored on the instance; not read anywhere in this class.
        """
        self.env = env
        # NOTE: despite the attribute name, this holds the planner's *policy*.
        self.planner = planner.policy
        self.verbose = verbose

        self._setup()

    def _setup(self) -> None:
        """Cache the space dimensions and the task embedding."""
        self.obs_dim = get_flattened_obs_dim(self.env.observation_space)
        self.act_dim = get_act_dim(self.env.action_space)
        # Kept as a single row (1, -1); presumably so it pairs with the
        # batch-of-one observation passed in `predict` -- TODO confirm.
        self.task = get_skill_embed(None, self.env.env_name).reshape(1, -1)

    def predict(
        self, obs: np.ndarray, deterministic: bool = True
    ) -> Tuple[np.ndarray, None, Dict[str, Any]]:
        """Return the policy output for a single, unbatched observation.

        Args:
            obs: One observation without a batch dimension.
            deterministic: Forwarded unchanged to the policy's ``_predict``.

        Returns:
            A ``(act, None, {})`` tuple, where ``act`` is a freshly allocated
            NumPy array holding the policy output with every size-1 axis
            squeezed out.
        """
        # Prepend an explicit batch axis of size 1. An inferred `-1` could only
        # ever resolve to 1 here, and it fails outright on zero-size input.
        obs = obs.reshape((1,) + obs.shape)

        # 1. inference planner
        plan, _ = self.planner._predict(self.task, obs, deterministic)

        # `np.array` copies by default, so the caller already receives an
        # independent array; no extra `.copy()` is needed.
        act = np.array(plan.squeeze())
        return act, None, {}
